package com.atguigu.day11;

import com.atguigu.bean.WaterSensor;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.annotation.DataTypeHint;
import org.apache.flink.table.annotation.FunctionHint;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.functions.TableFunction;
import org.apache.flink.types.Row;

import static org.apache.flink.table.api.Expressions.$;
import static org.apache.flink.table.api.Expressions.call;

/**
 * Demonstrates a user-defined aggregate function (UDAF) in Flink Table API / SQL.
 *
 * <p>Reads comma-separated lines from a socket, converts them to {@link WaterSensor} records,
 * turns the stream into a table and computes the average {@code vc} per {@code id} with
 * {@link MyUDAF}.
 */
public class Flink02_Fun_UDF_AggFun {
    public static void main(String[] args) throws Exception {
        // 1. Obtain the stream execution environment.
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        env.setParallelism(1);

        // 2. Read lines from a socket (localhost:9999) and parse each "id,<long>,vc" line into a
        //    WaterSensor. NOTE(review): the second field is parsed as a long (presumably a
        //    timestamp — confirm against WaterSensor). Malformed lines will fail the job.
        SingleOutputStreamOperator<WaterSensor> waterSensorDStream = env
                .socketTextStream("localhost", 9999)
                .map(new MapFunction<String, WaterSensor>() {
                    @Override
                    public WaterSensor map(String value) throws Exception {
                        String[] split = value.split(",");
                        return new WaterSensor(split[0], Long.parseLong(split[1]), Integer.parseInt(split[2]));
                    }
                });

        // 3. Obtain the table environment on top of the stream environment.
        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);

        // 4. Convert the stream into a table.
        Table table = tableEnv.fromDataStream(waterSensorDStream);

        // Option 1: call the UDF inline by class, without registering it (Table API).
//        table
//                .groupBy($("id"))
//                .select($("id"),call(MyUDAF.class,$("vc")))
//                .execute().print();

        // Option 2: register the UDF under a name first, then use it by that name.
        tableEnv.createTemporarySystemFunction("myUDAF", MyUDAF.class);

        // Option 2a: use the registered UDF from the Table API.
//        table
//                .groupBy($("id"))
//                .select($("id"),call("myUDAF", $("vc")))
//                .execute().print();

        // Option 2b: use the registered UDF from SQL. Concatenating `table` into the query relies
        // on Table#toString(), which registers the table under a generated unique name.
        tableEnv.executeSql("select " +
                "id," +
                "myUDAF(vc) as vcAvg " +
                "from "+table+" " +
                "group by id").print();

    }

    /**
     * Accumulator for {@link MyUDAF}: the running sum and count of the non-null values
     * accumulated so far.
     *
     * <p>NOTE(review): {@code sum} is an {@code Integer}, so it can silently overflow for very
     * large inputs. Widening it to {@code Long} would change this public class's field types,
     * so it is left as is.
     */
    public static class Myacc{
        public Integer sum;
        public Integer count;
    }

    /**
     * Aggregate function (many rows in, one value out) that computes the average of its
     * {@code Integer} argument per group, e.g. the average {@code vc} per {@code id}.
     *
     * <p>NULL inputs are ignored, matching SQL {@code AVG}. A group with no non-null values
     * yields {@code null}. {@code accumulate}, {@code retract} and {@code merge} are not
     * declared on {@link AggregateFunction}. Flink looks them up by name, which is why they
     * carry no {@code @Override}.
     */
    public static class MyUDAF extends AggregateFunction<Double,Myacc>{

        /** Creates an empty accumulator (sum = 0, count = 0). */
        @Override
        public Myacc createAccumulator() {
            Myacc myacc = new Myacc();
            myacc.sum = 0;
            myacc.count = 0;
            return myacc;
        }

        /**
         * Adds one input value to the accumulator.
         *
         * @param acc   accumulator to update in place
         * @param value input value; {@code null} is skipped (unboxing it would throw an NPE)
         */
        public void accumulate(Myacc acc, Integer value){
            if (value == null) {
                return;
            }
            acc.sum += value;
            acc.count++;
        }

        /**
         * Removes one previously accumulated value. This lets the function consume updating
         * (retracting) input.
         *
         * @param acc   accumulator to update in place
         * @param value value to remove; {@code null} is skipped, mirroring {@link #accumulate}
         */
        public void retract(Myacc acc, Integer value){
            if (value == null) {
                return;
            }
            acc.sum -= value;
            acc.count--;
        }

        /**
         * Merges partial accumulators into {@code acc}. Flink needs this for merge-based plans
         * such as session windows and two-phase (local/global) aggregation.
         *
         * @param acc    accumulator that receives the merged state
         * @param others partial accumulators to fold into {@code acc}
         */
        public void merge(Myacc acc, Iterable<Myacc> others){
            for (Myacc other : others) {
                acc.sum += other.sum;
                acc.count += other.count;
            }
        }

        /**
         * Returns the final average.
         *
         * @param accumulator accumulator holding the group's sum and count
         * @return {@code sum / count} as a double, or {@code null} if no value was accumulated
         */
        @Override
        public Double getValue(Myacc accumulator) {
            if (accumulator.count == 0) {
                return null;
            }
            // Multiply by 1D so the division happens in floating point, not integer division.
            return accumulator.sum * 1D / accumulator.count;
        }
    }

}
